{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "expired-priority",
   "metadata": {},
   "source": [
    "# SPLADE: Sparse Lexical and Expansion Model for First Stage Ranking"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "becoming-television",
   "metadata": {},
   "source": [
    "This notebook gives a minimal example usage of SPLADE.\n",
    "\n",
    "* We provide models via Hugging Face (https://huggingface.co/naver)\n",
    "* See [Naver Labs Europe website](https://europe.naverlabs.com/research/machine-learning-and-optimization/splade-models/) for other intermediate models.\n",
    "\n",
    "| model | MRR@10 (MS MARCO dev) | recall@1000 (MS MARCO dev) | expected FLOPS | ~ avg q length | ~ avg d length | \n",
    "| --- | --- | --- | --- | --- | --- |\n",
    "| `naver/splade_v2_max` (**v2** [HF](https://huggingface.co/naver/splade_v2_max)) | 34.0 | 96.5 | 1.32 | 18 | 92 |\n",
    "| `naver/splade_v2_distil` (**v2** [HF](https://huggingface.co/naver/splade_v2_distil)) | 36.8 | 97.9 | 3.82 | 25 | 232 |\n",
    "| `naver/splade-cocondenser-selfdistil` (**v2bis**, [HF](https://huggingface.co/naver/splade-cocondenser-selfdistil))| 37.6 | 98.4 | 2.32 | 56 | 134 |\n",
    "| `naver/splade-cocondenser-ensembledistil` (**v2bis**, [HF](https://huggingface.co/naver/splade-cocondenser-ensembledistil)) | 38.3 | 98.3  | 1.85 | 44 | 120 |"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "tight-milton",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from transformers import AutoModelForMaskedLM, AutoTokenizer\n",
    "from splade.models.transformer_rep import Splade"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "mighty-invention",
   "metadata": {},
   "outputs": [],
   "source": [
    "# set the dir for trained weights\n",
    "\n",
    "##### v2\n",
    "# model_type_or_dir = \"naver/splade_v2_max\"\n",
    "# model_type_or_dir = \"naver/splade_v2_distil\"\n",
    "\n",
    "### v2bis, directly download from Hugging Face\n",
    "# model_type_or_dir = \"naver/splade-cocondenser-selfdistil\"\n",
    "model_type_or_dir = \"naver/splade-cocondenser-ensembledistil\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "centered-watershed",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "5c2ed950bbbf4b878f40ffa6a26e3724",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading:   0%|          | 0.00/670 [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "5e4d5ad49d3e4c869f7fea707b706c57",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading:   0%|          | 0.00/418M [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "aab7cd5cf970425fabde9d5ccfc26d4d",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading:   0%|          | 0.00/466 [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "106e18eff08d408297f4201a1397e561",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading:   0%|          | 0.00/226k [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "44064282cb2c4c729fbc9cbab82c4d85",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading:   0%|          | 0.00/455k [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "1eef36ee53684c0288c09099917ef2cd",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading:   0%|          | 0.00/112 [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# loading model and tokenizer\n",
    "\n",
    "model = Splade(model_type_or_dir, agg=\"max\")\n",
    "model.eval()\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_type_or_dir)\n",
    "reverse_voc = {v: k for k, v in tokenizer.vocab.items()}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "forward-movement",
   "metadata": {},
   "outputs": [],
   "source": [
    "# example document from MS MARCO passage collection (doc_id = 8003157)\n",
    "\n",
    "doc = \"Glass and Thermal Stress. Thermal Stress is created when one area of a glass pane gets hotter than an adjacent area. If the stress is too great then the glass will crack. The stress level at which the glass will break is governed by several factors.\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "rolled-canon",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "number of actual dimensions:  126\n",
      "SPLADE BOW rep:\n",
      " [('stress', 2.25), ('glass', 2.23), ('thermal', 2.18), ('glasses', 1.65), ('pan', 1.62), ('heat', 1.56), ('stressed', 1.42), ('crack', 1.31), ('break', 1.12), ('cracked', 1.1), ('hot', 0.93), ('created', 0.9), ('factors', 0.81), ('broken', 0.73), ('caused', 0.71), ('too', 0.71), ('damage', 0.69), ('if', 0.68), ('hotter', 0.65), ('governed', 0.61), ('heating', 0.59), ('temperature', 0.59), ('adjacent', 0.59), ('cause', 0.58), ('effect', 0.57), ('fracture', 0.56), ('bradford', 0.55), ('strain', 0.53), ('hammer', 0.51), ('brian', 0.48), ('error', 0.47), ('windows', 0.45), ('will', 0.45), ('reaction', 0.42), ('create', 0.42), ('windshield', 0.41), ('heated', 0.41), ('factor', 0.4), ('cracking', 0.39), ('failure', 0.38), ('mechanical', 0.38), ('when', 0.38), ('formed', 0.38), ('bolt', 0.38), ('mechanism', 0.37), ('warm', 0.37), ('areas', 0.36), ('area', 0.36), ('energy', 0.34), ('disorder', 0.33), ('barry', 0.33), ('shock', 0.32), ('determined', 0.32), ('gage', 0.32), ('sash', 0.31), ('theory', 0.31), ('level', 0.31), ('resistant', 0.31), ('brake', 0.3), ('window', 0.3), ('crash', 0.3), ('hazard', 0.29), ('##ink', 0.27), ('ceramic', 0.27), ('storm', 0.25), ('problem', 0.25), ('issue', 0.24), ('impact', 0.24), ('fridge', 0.24), ('injury', 0.23), ('ross', 0.22), ('causes', 0.22), ('affect', 0.21), ('pressure', 0.21), ('fatigue', 0.21), ('leak', 0.21), ('eye', 0.2), ('frank', 0.2), ('cool', 0.2), ('might', 0.19), ('gravity', 0.18), ('ray', 0.18), ('static', 0.18), ('collapse', 0.18), ('physics', 0.18), ('wave', 0.18), ('reflection', 0.17), ('parker', 0.17), ('strike', 0.17), ('hottest', 0.17), ('burst', 0.16), ('chance', 0.16), ('burn', 0.14), ('rubbing', 0.14), ('interference', 0.14), ('bailey', 0.13), ('vibration', 0.12), ('gilbert', 0.12), ('produced', 0.12), ('rock', 0.12), ('warmer', 0.11), ('get', 0.11), ('drink', 0.11), ('fireplace', 0.11), ('ruin', 0.1), ('brittle', 0.1), ('fragment', 0.1), ('stumble', 0.09), ('formation', 0.09), ('shatter', 0.08), ('great', 0.08), ('friction', 0.08), ('flash', 0.07), ('cracks', 0.07), ('levels', 0.07), ('smash', 0.04), ('fail', 0.04), ('fra', 0.04), ('##glass', 0.03), ('variables', 0.03), ('because', 0.02), ('knock', 0.02), ('sun', 0.02), ('crush', 0.01), ('##e', 0.01), ('anger', 0.01)]\n"
     ]
    }
   ],
   "source": [
    "# now compute the document representation\n",
    "with torch.no_grad():\n",
    "    doc_rep = model(d_kwargs=tokenizer(doc, return_tensors=\"pt\"))[\"d_rep\"].squeeze()  # (sparse) doc rep in voc space, shape (30522,)\n",
    "\n",
    "# get the number of non-zero dimensions in the rep:\n",
    "col = torch.nonzero(doc_rep).squeeze().cpu().tolist()\n",
    "print(\"number of actual dimensions: \", len(col))\n",
    "\n",
    "# now let's inspect the bow representation:\n",
    "weights = doc_rep[col].cpu().tolist()\n",
    "d = {k: v for k, v in zip(col, weights)}\n",
    "sorted_d = {k: v for k, v in sorted(d.items(), key=lambda item: item[1], reverse=True)}\n",
    "bow_rep = []\n",
    "for k, v in sorted_d.items():\n",
    "    bow_rep.append((reverse_voc[k], round(v, 2)))\n",
    "print(\"SPLADE BOW rep:\\n\", bow_rep)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
